{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLflow Prophet Tutorial\n",
    "\n",
    "This `train.pynb` Jupyter notebook predicts page views of a wikipedia page using [Prophet](https://facebook.github.io/prophet/).  \n",
    "\n",
    "> This is the Jupyter notebook version of the `train.py` example\n",
    "\n",
    "Attribution\n",
    "* The data set used in this example is from https://github.com/facebook/prophet/blob/master/examples/example_wp_log_peyton_manning.csv\n",
    "* The data was scraped from https://en.wikipedia.org/wiki/Peyton_Manning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mlflow.pyfunc\n",
    "import cloudpickle\n",
    "import fbprophet\n",
    "from fbprophet import Prophet\n",
    "\n",
    "class FbProphetWrapper(mlflow.pyfunc.PythonModel):\n",
    "\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "        super(FbProphetWrapper, self).__init__()\n",
    "\n",
    "\n",
    "    def load_context(self, context):\n",
    "        from fbprophet import Prophet\n",
    "        return\n",
    "\n",
    "    def predict(self, context, model_input):\n",
    "        future = self.model.make_future_dataframe(periods=model_input['periods'][0])\n",
    "        return self.model.predict(future)\n",
    "\n",
    "conda_env = {\n",
    "    'channels': ['defaults', 'conda-forge'],\n",
    "    'dependencies': [\n",
    "      'fbprophet={}'.format(fbprophet.__version__),\n",
    "      'cloudpickle={}'.format(cloudpickle.__version__),\n",
    "    ],\n",
    "    'name': 'fbp_env'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# page view stats\n",
    "def train(rolling_window):\n",
    "    import warnings\n",
    "    import sys\n",
    "\n",
    "    import pandas as pd\n",
    "    import numpy as np\n",
    "    # Python\n",
    "    from fbprophet import Prophet\n",
    "    from fbprophet.diagnostics import cross_validation\n",
    "    from fbprophet.diagnostics import performance_metrics\n",
    "\n",
    "    import mlflow\n",
    "    import mlflow.pyfunc\n",
    "    \n",
    "    import logging\n",
    "    logging.basicConfig(level=logging.WARN)\n",
    "    logger = logging.getLogger(__name__)\n",
    "\n",
    " \n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "    np.random.seed(40)\n",
    "\n",
    "    # Read the csv file from the URL\n",
    "    csv_url =\\\n",
    "        'https://raw.githubusercontent.com/facebook/prophet/e21a05f4f9290649255a2a306855e8b4620816d7/examples/example_wp_log_peyton_manning.csv'\n",
    "    try:\n",
    "        df = pd.read_csv(csv_url)\n",
    "    except Exception as e:\n",
    "        logger.exception(\n",
    "            \"Unable to download training & test CSV, check your internet connection. Error: %s\", e)\n",
    "\n",
    "    \n",
    "    # Useful for multiple runs (only doing one run in this sample notebook)    \n",
    "    with mlflow.start_run():\n",
    "        m = Prophet()\n",
    "        m.fit(df)\n",
    "\n",
    "        # Evaluate Metrics\n",
    "        df_cv = cross_validation(m, initial='730 days', period='180 days', horizon = '365 days')\n",
    "        df_p = performance_metrics(df_cv, rolling_window=rolling_window)\n",
    "\n",
    "        # Print out metrics\n",
    "        print(\"Prophet model (rolling_window=%f):\" % (rolling_window))\n",
    "        print(\"  CV: \\n%s\" % df_cv.head())\n",
    "        print(\"  Perf: \\n%s\" % df_p.head())\n",
    "\n",
    "        # Log parameter, metrics, and model to MLflow\n",
    "        mlflow.log_param(\"rolling_window\", rolling_window)\n",
    "        mlflow.log_metric(\"rmse\", df_p.loc[0,'rmse'])\n",
    "\n",
    "        model = FbProphetWrapper(m)\n",
    "        mlflow.pyfunc.log_model(\"model\", conda_env=conda_env, python_model=model)\n",
    "        model_uri = \"runs:/{run_id}/model\".format(run_id=mlflow.active_run().info.run_id)\n",
    "        return model_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_uri = train(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import mlflow.pyfunc\n",
    "\n",
    "model = mlflow.pyfunc.load_model(model_uri)\n",
    "\n",
    "data = {'periods': [5]}\n",
    "df = pd.DataFrame(data) \n",
    "out = model.predict(model_input=df)\n",
    "out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
